{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "356fda40",
   "metadata": {},
   "source": [
    "## 深度学习特征\n",
    "\n",
    "提取CT、MRI、内镜、Xray等影像数据的深度学习特征。\n",
    "\n",
    "### Onekey步骤\n",
    "\n",
    "1. 将待提取的数据转化成nii(nii.gz)，需要使用到OKT-convert2nii工具。\n",
    "2. 获取到指定目录的所有图像数据。\n",
    "3. 选择要提取什么样的模型的深度学习特征，目前Onekey支持ResNet3d深度学习模型。（可以考虑使用Onekey进行迁移学习）\n",
    "  > 只支持ResNet3d，是因为目前仅有resnet存在预训练的模型。\n",
    "4. 提取特征，保存特征文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ae13c54",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### OKT-convert2nii\n",
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('OKT-convert2nii')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdb5273e",
   "metadata": {},
   "source": [
    "### 使用crop max roi工具保存3dnii数据\n",
    "\n",
    "参数axis_3d > 2 即可使用保存roi的3d数据，不进行最大面积截断"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b9db3f5",
   "metadata": {},
   "source": [
    "### 获取待提取特征的NII数据\n",
    "\n",
    "提供两种批量处理的模式：\n",
    "1. 目录模式，提取指定目录下的所有jpg文件的特征。\n",
    "2. 文件模式，待提取的数据存储在文件中，每行一个样本。\n",
    "\n",
    "当然也可以在最后自己指定手动提取指定若干文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae43e311",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from onekey_algo import OnekeyDS as okds\n",
    "# 目录模式\n",
    "mydir = r'数据目录' \n",
    "directory = os.path.expanduser(mydir)\n",
    "test_samples = [os.path.join(directory, p) \n",
    "                for p in os.listdir(directory) if p.endswith('.nii') or p.endswith('.nii.gz')]\n",
    "\n",
    "# 文件模式\n",
    "# test_file = ''\n",
    "# with open(test_file) as f:\n",
    "#     test_samples = [l.strip() for l in f.readlines()]\n",
    "\n",
    "# 自定义模式\n",
    "# test_sampleses = ['path2jpg']\n",
    "test_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ec8a564",
   "metadata": {},
   "source": [
    "## 确定提取特征\n",
    "\n",
    "通过关键词获取要提取那一层的特征。\n",
    "\n",
    "### 支持的模型名称\n",
    "\n",
    "模型名称替换代码中的 `model_name`变量的值。\n",
    "\n",
    "| **模型系列** | **模型名称**                                                 |\n",
    "| ------------ | ------------------------------------------------------------ |\n",
    "| ResNet       | resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef024a4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_core.core import create_model\n",
    "from onekey_algo.custom.components.comp2 import extract3d, init_from_onekey3d\n",
    "\n",
    "viz_dir = \"你自己训练的3D模型的viz目录。\"\n",
    "model, transformer, device = init_from_onekey3d(viz_dir)\n",
    "\n",
    "for n, m in model.named_modules():\n",
    "    print('Feature name:', n, \"|| Module:\", m)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c2a22e8",
   "metadata": {},
   "source": [
    "## 提取特征\n",
    "\n",
    "`Feature name:` 之后的名称为要提取的特征名，例如`layer3.0.conv2`, 一般深度学习特征提取最后一层，例如`avgpool`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5b9a01f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from functools import partial\n",
    "from onekey_algo.custom.components.comp2 import extract3d, print_feature_hook, reg_hook_on_module\n",
    "from monai.data import ImageDataset\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "feature_name = 'to_latent'\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "with open('feature.csv', 'w') as outfile:\n",
    "    hook = partial(print_feature_hook, fp=outfile)\n",
    "    find_num = reg_hook_on_module(feature_name, model, hook)\n",
    "    val_ds = ImageDataset(image_files=test_samples, transform=transformer)\n",
    "    # create a validation data loader\n",
    "    val_loader = DataLoader(val_ds, batch_size=1, num_workers=0)\n",
    "    results = extract3d(val_loader, test_samples, model, device, fp=outfile)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ed292b8",
   "metadata": {},
   "source": [
    "### 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfd4da6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.Manager import onekey_show\n",
    "onekey_show('深度学习特征提取|特征读取')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c305dd17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "features = pd.read_csv('feature.csv', sep=',', header=None)\n",
    "features.columns=['ID'] + [f\"DL_{i}\" for i in list(features.columns[1:])]\n",
    "features.to_csv('feature.csv', index=False)\n",
    "features.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a65173b",
   "metadata": {},
   "source": [
    "### 深度特征压缩\n",
    "\n",
    "深度学习特征压缩，注意压缩到的维度需要小于样本数\n",
    "\n",
    "```python\n",
    "def compress_df_feature(features: pd.DataFrame, dim: int, not_compress: Union[str, List[str]] = None,\n",
    "                        prefix='') -> pd.DataFrame:\n",
    "    \"\"\"\n",
    "    压缩深度学习特征\n",
    "    Args:\n",
    "        features: 特征DataFrame\n",
    "        dim: 需要压缩到的维度，此值需要小于样本数\n",
    "        not_compress: 不进行压缩的列。\n",
    "        prefix: 所有特征的前缀。\n",
    "\n",
    "    Returns:\n",
    "\n",
    "    \"\"\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49423b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from onekey_algo.custom.components.comp1 import compress_df_feature\n",
    "\n",
    "cm_features = compress_df_feature(features=features, dim=32, prefix='DL_', not_compress='ID')\n",
    "cm_features.to_csv('compress_features.csv', header=True, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b10614b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
